"""
Routines and classes to manipulate Octopus inputs
"""

import os
import gzip
import numpy as np
import pychemia
from ..codes import CodeInput


class OctopusInput(CodeInput):
    """
    Manipulate an octopus input file

    The parsed content is stored in ``self.variables``, a dictionary whose
    keys are variable names. Scalar variables map to the raw string found
    after the ``=`` sign; ``%`` blocks map to a list of strings (single-row
    block) or a list of lists of strings (multi-row block).
    """

    def __init__(self, input_file='inp'):
        """
        Converts a given octopus input file
        into a dictionary where the keys are
        variable names and the values are scalars
        or list if the variable is a block

        :param input_file: (str) Path to the octopus input file. The file is
                           parsed only if it exists; otherwise the set of
                           variables starts empty.
        """
        CodeInput.__init__(self)
        self.variables = {}
        self.input_file = input_file
        if os.path.isfile(input_file):
            self.read()

    def read(self):
        """
        Parse ``self.input_file`` and store its content in ``self.variables``

        Recognized lines:

        * ``key = value`` (exactly one ``=``): stored as a stripped string.
        * ``%name`` ... ``%``: a block; each row is split on ``|`` and every
          cell is stripped. One row gives a flat list, several rows give a
          list of lists.
        * Lines starting with ``#`` and blank lines are ignored.

        Any other line is reported on standard output as not parsed.
        """
        key = None
        values = None
        multivalue = False

        with open(self.input_file) as rfile:
            for line in rfile:
                stripped = line.strip()

                # Blank or whitespace-only lines carry no information, both
                # inside and outside blocks. Indexing stripped[0] on them
                # would raise IndexError.
                if not stripped:
                    continue

                if multivalue:
                    if stripped == '%':
                        # A lone '%' closes the current block
                        self.variables[key] = values
                        multivalue = False
                        values = None
                        continue

                    row = [x.strip() for x in stripped.split('|')]
                    if values is None:
                        values = row
                    else:
                        # A second row arrived: promote the flat list
                        # to a list of rows
                        if not isinstance(values[0], list):
                            values = [values]
                        values.append(row)
                    continue

                if stripped[0] == '#':
                    continue

                parts = line.split('=')
                if len(parts) == 2:
                    key = parts[0].strip()
                    self.variables[key] = parts[1].strip()
                elif stripped[0] == '%' and len(stripped) > 1:
                    key = stripped[1:]
                    multivalue = True
                else:
                    print('Line not parsed:', line)

    def write(self, filename):
        """
        Write an input dictionary into a file
        the variables are sorted by kind and written
        in their respective place

        :param filename: (str) Filename for the octopus input that will be created
        """
        with open(filename, 'w') as wfile:
            wfile.write(str(self))

    @staticmethod
    def _section_header(title):
        """
        Build the three-line comment banner that precedes a group of variables
        """
        rule = '#' + 60 * '-'
        return '\n' + rule + '\n' + '#' + 4 * ' ' + title + '\n' + rule + '\n\n'

    def __str__(self):
        """
        Creates a string representation of
        the input as it will be written
        in a file

        Variables that do not belong to any known group are emitted first
        under a 'Definitions' banner (and listed on standard output); the
        known groups follow in alphabetical order, each with its variables
        sorted by name.
        """
        # Get all variables and their groups
        oct_vars = get_vars()

        written = set()
        grouped = []
        # Writing the known groups
        for igroup in sorted(oct_vars):
            banner_done = False
            for name in sorted(oct_vars[igroup]):
                if name not in self.variables or name in written:
                    continue
                if not banner_done:
                    grouped.append(self._section_header(igroup))
                    banner_done = True
                grouped.append(_write_key(name, self.variables[name]))
                # Each variable is written once, even if it is listed twice
                written.add(name)

        # Remaining variables keep the order of the dictionary
        leftovers = [name for name in self.variables if name not in written]

        ungrouped = []
        if leftovers:
            # Writing all remaining variables that have no group associated
            ungrouped.append(self._section_header('Definitions'))
            print('Non grouped variables:')
            for name in leftovers:
                ungrouped.append(_write_key(name, self.variables[name]))
                print(' * ', name)

        return ''.join(ungrouped) + ''.join(grouped)


def execute(basedir, num_threads=1, num_procs=2):
    """
    Call octopus in parallel with a given number of threads and
    MPI process

    A small shell script named 'script.sh' is written and then handed to
    ``pychemia.runner.execute`` to be run with bash.

    :param basedir: (str) Working directory for execution
    :param num_threads: (int) Number of OpenMP Threads to use
    :param num_procs: (int) Number of MPI processes to create
    """
    # NOTE(review): the script is created relative to the current working
    # directory, not inside basedir -- confirm that this is what
    # pychemia.runner.execute expects.
    with open('script.sh', 'w') as wfile:
        wfile.write('export OMP_NUM_THREADS=' + str(num_threads) + '\n')
        wfile.write('mpirun -np ' + str(num_procs) + ' octopus_mpi')
    pychemia.runner.execute(basedir, 'bash', 'script.sh')


def get_vars():
    """
    Get all the variables known to octopus, organized by group

    The information is read from the 'octopus_variables.conf' file shipped
    with pychemia. In that file a line starting with '[' opens a group (the
    group name is the text between the first and last characters of the
    line) and every other non-empty line is a variable name that belongs to
    the most recently opened group.

    :return: (dict) Group names as keys, list of variable names as values
    :raises ValueError: If a variable appears before any group header
    """
    conf_path = os.path.dirname(pychemia.__path__[0]) + '/pychemia/code/octopus/octopus_variables.conf'
    oct_variables = {}
    section = None

    with open(conf_path) as data:
        for line in data:
            line = line.strip()
            if not line:
                continue
            if line[0] == '[':
                section = line[1:-1]
            elif section is None:
                # Without this check the failure would be an opaque
                # UnboundLocalError on 'section'
                raise ValueError('Variable %r found before any [group] header in %s' % (line, conf_path))
            else:
                oct_variables.setdefault(section, []).append(line)
    return oct_variables


class OpenDX:
    """
    Manipulate the OpenDX file generated by octopus

    Attributes set by the constructor:

    * ``nsize``: integer array with the last three tokens of the first line
    * ``origin``: float array with the last three tokens of the second line
    * ``delta``: 3x3 float array built from tokens 1-3 of lines 3 to 5
    * ``field``: float array of shape ``tuple(nsize)`` with the data values
    """

    def __init__(self, filename):
        """
        Read a gzip-compressed OpenDX file generated by octopus

        The data values are taken one per line starting at the eighth line
        of the file, ``prod(nsize)`` lines in total.

        :param filename: (str) Path to the gzip-compressed OpenDX file
        """
        with gzip.open(filename, 'rt') as rfile:
            data = rfile.readlines()

        self.nsize = np.array([int(x) for x in data[0].split()[-3:]])
        self.origin = np.array([float(x) for x in data[1].split()[-3:]])

        delta_rows = []
        for row in data[2:5]:
            tokens = row.split()
            delta_rows.append([float(tokens[i]) for i in (1, 2, 3)])
        self.delta = np.array(delta_rows)

        npoints = int(np.prod(self.nsize))
        # A list is required here: on Python 3 'map' returns an iterator and
        # np.array would wrap it as a single object instead of converting
        # its items, making the reshape below fail.
        values = [float(x) for x in data[7:7 + npoints]]
        self.field = np.array(values).reshape(tuple(self.nsize))

    def integrate_x(self):
        """
        Accumulate the field over the second and third axes

        :return: (tuple) ``xvalue``, the positions along the first axis
                 (``origin[0] + i * delta[0, 0]`` for each grid index ``i``),
                 and ``yvalue``, the plain sum of the field over the other
                 two axes for each of those positions. No volume or area
                 element is applied to the sums.
        """
        # Built from an integer index so the length always equals nsize[0];
        # np.arange with a floating point step can be off by one element.
        xvalue = self.origin[0] + self.delta[0, 0] * np.arange(self.nsize[0])
        yvalue = self.field.sum(axis=(1, 2))
        return xvalue, yvalue


def _write_key(key, value):
    """
    Write in the proper format the key and the value in
    a octopus input file
    """
    oct_str = ""
    keylen = 20

    if isinstance(value, int) or isinstance(value, float):
        oct_str = oct_str + (key.ljust(keylen) + " = " + str(value) + '\n')
    elif isinstance(value, str):
        if (key == 'XYZCoordinates' or key == 'XYZVelocities') and '"' not in value:
            oct_str = oct_str + (key.ljust(keylen) + ' = "' + value + '"\n')
        else:
            oct_str = oct_str + (key.ljust(keylen) + " = " + value + '\n')

    elif isinstance(value, list):
        oct_str += '\n%' + key + "\n"
        for itable in range(len(value)):
            if isinstance(value[itable], list):
                for jtable in range(len(value[itable])):
                    oct_str += str(value[itable][jtable]).ljust(10)

                    if jtable < len(value[itable]) - 1:
                        oct_str += " | "
                    if jtable == len(value[itable]) - 1 and itable < len(value) - 1:
                        oct_str += "\n"
            else:
                oct_str += str(value[itable])
                if itable < len(value) - 1:
                    oct_str += " | "
        oct_str += "\n%\n\n"

    return oct_str
